Attention Mechanism
Attention mechanisms address a limitation of traditional neural architectures such as CNNs and RNNs, which compress an input into a fixed-size representation regardless of its length or content. Attention offers a flexible way to handle inputs of varying size and content, such as long text sequences, by letting the model focus dynamically on different parts of the input.
Database and Queries Analogy
- Database Model: Databases use keys and values, where queries retrieve values based on keys. This concept is analogous to neural networks where queries fetch relevant information from a set of data (keys and values).
- Mathematical Formulation: A query $\mathbf{q}$ fetches values according to its similarity with the keys: $\mathrm{Attention}(\mathbf{q}, \mathcal{D}) = \sum_{i=1}^{m} \alpha(\mathbf{q}, \mathbf{k}_i)\,\mathbf{v}_i$, where the $\alpha(\mathbf{q}, \mathbf{k}_i)$ are the attention weights, indicating the importance of each value $\mathbf{v}_i$ given the query $\mathbf{q}$ and key $\mathbf{k}_i$.
Normalization of Weights
- Softmax Function: Normalizes attention weights so that they are nonnegative and sum to 1: $\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j} \exp(a(\mathbf{q}, \mathbf{k}_j))}$, where $a(\mathbf{q}, \mathbf{k}_i)$ is an attention scoring function.
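As a minimal sketch of this normalization step (plain Python, no framework assumed):

```python
import math

def softmax(scores):
    # subtract the max score for numerical stability; the result is unchanged
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# larger scores get larger weights; weights are nonnegative and sum to 1
weights = softmax([2.0, 1.0, 0.1])
```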
Attention Pooling by Similarity
- Nadaraya-Watson Estimator: Uses a similarity kernel $K$ to relate queries to keys, a precursor of modern attention mechanisms: $f(\mathbf{q}) = \sum_{i} \frac{K(\mathbf{q} - \mathbf{k}_i)}{\sum_{j} K(\mathbf{q} - \mathbf{k}_j)}\, v_i$.
- Common Kernels: Gaussian, Boxcar, and Epanechnikov kernels illustrate different ways of computing attention weights from the distance between queries and keys.
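A toy sketch of Nadaraya-Watson attention pooling with a Gaussian kernel (scalar keys and values, chosen here purely for illustration):

```python
import math

def gaussian_kernel(u, sigma=1.0):
    # unnormalized Gaussian similarity; constants cancel in the normalization
    return math.exp(-u * u / (2 * sigma * sigma))

def nadaraya_watson(query, keys, values, kernel=gaussian_kernel):
    # attention weights: kernel similarities between query and keys, normalized
    sims = [kernel(query - k) for k in keys]
    total = sum(sims)
    weights = [s / total for s in sims]
    # the estimate is a weighted average of the values
    return sum(w * v for w, v in zip(weights, values))

keys = [0.0, 1.0, 2.0, 3.0]
values = [1.0, 2.0, 3.0, 4.0]
estimate = nadaraya_watson(1.0, keys, values)  # pooled value near key 1.0
```

The estimate stays inside the range of the values, and as the query moves, the weights shift toward the nearest keys.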
Attention Scoring Functions
- Dot Product Attention: Simplifies the computation of attention scores to a (scaled) dot product of queries and keys: $a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i / \sqrt{d}$, where $d$ is the dimension of the queries and keys.
- Additive Attention: Suitable when queries and keys have different dimensions; combines linear transformations with a nonlinearity: $a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k})$.
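A minimal sketch of scaled dot-product attention over a small set of keys and values (plain Python, vectors as lists, softmax inlined; the toy vectors are chosen only for illustration):

```python
import math

def scaled_dot_score(q, k):
    # a(q, k) = q . k / sqrt(d); scaling keeps score magnitudes stable as d grows
    return sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q))

def attention(q, keys, values):
    scores = [scaled_dot_score(q, k) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # output: weighted sum of the value vectors
    return [sum(w * v[i] for w, v in zip(weights, values))
            for i in range(len(values[0]))]

q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
out = attention(q, keys, values)  # weighted toward the first value vector
```

Because the query aligns with the first key, the first value dominates the output, but the second still contributes: attention blends rather than selects.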
The Bahdanau Attention Mechanism
- Dynamic Context Variable: Updates the context variable at each decoding step, allowing the model to focus on different parts of the input sequence dynamically.
- Mathematical Formulation: The context variable at decoding step $t'$ is a weighted sum of all encoder hidden states, adapting to the decoder's needs: $\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t$, where $\mathbf{s}_{t'-1}$ is the previous decoder state and $\mathbf{h}_t$ the encoder state at time $t$.
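A scalar sketch of this computation, using additive-style scoring with scalar weights standing in for the weight matrices (an illustrative simplification, not the full parameterization):

```python
import math

def additive_score(s, h, w_q=1.0, w_k=1.0, w_v=1.0):
    # a(s, h) = w_v * tanh(w_q * s + w_k * h); scalars stand in for W_q, W_k, w_v
    return w_v * math.tanh(w_q * s + w_k * h)

def context(s_prev, encoder_states):
    # normalize the scores with softmax, then take a weighted sum of all states
    scores = [additive_score(s_prev, h) for h in encoder_states]
    m = max(scores)
    exps = [math.exp(x - m) for x in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    return sum(w * h for w, h in zip(weights, encoder_states))

# the context changes with the decoder state, refocusing on the input sequence
c1 = context(-1.0, [0.5, 1.0, 1.5])
c2 = context(1.0, [0.5, 1.0, 1.5])
```

Different decoder states yield different contexts from the same encoder states, which is exactly the dynamic refocusing the bullet above describes.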
Attention Mechanism in Seq2Seq Translation
The PyTorch tutorial on seq2seq translation introduces an attention mechanism to enhance the translation model from French to English. Here's a concise explanation of how attention is applied:
Basic Model Structure
- Seq2Seq Framework: Consists of an encoder and a decoder, both implemented with RNNs. The encoder processes the input sequence into a context vector, which the decoder uses to produce the output sequence.
Code Example
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

MAX_LENGTH = 10  # maximum sentence length, defined earlier in the tutorial

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length
        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        # attention weights are computed from the input embedding and hidden state
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.dropout(self.embedding(input).view(1, 1, -1))
        # score each encoder position, then normalize with softmax
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        # context vector: weighted sum of the encoder outputs
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        # combine the embedded input with the context, then feed the GRU
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = F.relu(self.attn_combine(output).unsqueeze(0))
        output, hidden = self.gru(output, hidden)
        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights
```
Reference and Useful Links
- 11. Attention Mechanisms and Transformers — Dive into Deep Learning 1.0.3 documentation
- Attention Is All You Need
- Attention in transformers, visually explained | Chapter 6, Deep Learning - YouTube
- Explainable AI: Visualizing Attention in Transformers - Comet
- GitHub - jessevig/bertviz: BertViz: Visualize Attention in NLP Models (BERT, GPT2, BART, etc.)